#include <UnitTest++.h>
#include "config.h"
#include "grid_factory.h"

using namespace mcrx;
using namespace std;

// Minimal payload stored in each grid cell by these tests: a single scalar.
//
// The converting constructor from double is intentionally implicit: the
// factories in this file return plain numbers (e.g. `return 1;`) from
// get_data() and rely on that conversion.
struct griddata {
  double i_;  // the cell's scalar value

  griddata() : i_(0) {}
  griddata(double i) : i_(i) {}

  // Data for a unified cell. Returns `sum` unchanged and ignores the
  // count. Presumably `sum` is the accumulated child data (the unification
  // tests expect 8 children of value 1 to unify to 8) -- verify against
  // adaptive_grid.
  static griddata unification(griddata sum, int /*n*/) {
    return griddata(sum.i_);
  }

  // Accumulate rhs into this value; returns *this for chaining.
  griddata& operator+=(const griddata& rhs) {
    i_ += rhs.i_;
    return *this;
  }

  // Product of the two values. Declared const so it also works on const
  // operands; neither operand is modified.
  griddata operator*(const griddata& rhs) const {
    griddata temp(*this);
    temp.i_ *= rhs.i_;
    return temp;
  }
};

// Copy rhs into lhs and return lhs so that calls can be chained.
griddata& assign(griddata& lhs, const griddata& rhs)
{
  return lhs = rhs;
}

typedef adaptive_grid<griddata> T_grid;
typedef T_grid::T_grid T_grid_impl;
typedef typename T_grid::T_cell T_cell;
typedef typename T_grid::T_cell_tracker T_cell_tracker;
typedef typename T_grid::T_qpoint T_qpoint;
typedef typename T_cell_tracker::T_code T_code;

struct trivial_factory {
  typedef griddata T_data;
  typedef refinement_accuracy_data<T_data> T_racc;

  int maxlevel;
  int n_threads_;

  trivial_factory(int ml, int nt=1) : maxlevel(ml), n_threads_(nt) {};

  bool refine_cell_p (const T_cell_tracker& c) {
    return c.code().level()<maxlevel; };
  T_data get_data (const T_cell_tracker& c) {
    return 1; };
  bool unify_cell_p (const T_cell_tracker& c, const T_racc& racc) {
    return false; };
  int n_threads () { return n_threads_; };
};

// test that uniform refinement to level 1 gives 8 cells, each with data 1
TEST(refinement1)
{
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  trivial_factory factory(1);
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_EQUAL(1, c.data()->i_);
  }
  CHECK_EQUAL(8, g.n_cells());
}

// uniform refinement to level 2: 8*8 = 64 cells, each with data 1
TEST(refinement2)
{
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  trivial_factory factory(2);
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_EQUAL(1, c.data()->i_);
  }
  CHECK_EQUAL(64, g.n_cells());
}



struct trivial_factory2 {
  typedef griddata T_data;
  typedef refinement_accuracy_data<T_data> T_racc;

  int maxlevel;
  int n_threads_;

  trivial_factory2(int ml, int nt=1) : maxlevel(ml), n_threads_(nt) {};

  bool refine_cell_p (const T_cell_tracker& c) {
    return c.code().level()<maxlevel; };
  T_data get_data (const T_cell_tracker& c) {
    return 1; };
  bool unify_cell_p (const T_cell_tracker& c, const T_racc& racc) {
    return true; };
  int n_threads () { return n_threads_; };
};

// test that refinement followed by unification works: with unification
// always allowed, the 8 level-1 cells collapse into one cell with data 8
TEST(unification0)
{
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  trivial_factory2 factory(1);
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_EQUAL(8, c.data()->i_);
  }
  CHECK_EQUAL(1, g.n_cells());
}

// two levels of refinement unified back into one cell holding 64
TEST(unification1)
{
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  trivial_factory2 factory(2);
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_EQUAL(64, c.data()->i_);
  }
  CHECK_EQUAL(1, g.n_cells());
}


struct radius_factory {
  typedef griddata T_data;
  typedef refinement_accuracy_data<T_data> T_racc;

  bool refine_cell_p (const T_cell_tracker& c) {
    const vec3d center(c.getcenter());
    const vec3d size(c.getsize());
    return (c.code().level()<4) && (mag(size)/mag(center)>0.5); };
  T_data get_data (const T_cell_tracker& c) {
    return mag(c.getcenter()); };
  bool unify_cell_p (const T_cell_tracker& c, const T_racc& racc) {
    return false; };
  int n_threads () { return 1; };
};

// test a simple refinement based on distance from origin: each visited cell
// must hold its center's distance from the origin, and must either meet the
// size/distance criterion or be at the maximum level (4)
TEST(radius_factory)
{
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  radius_factory factory;
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_CLOSE(mag(c.getcenter()), c.data()->i_, 1e-10);
    CHECK(mag(c.getsize())/mag(c.getcenter())<=0.5 || c.code().level()==4);
  }
}

// multiple threads: same setup as refinement2, but built with 2 threads
TEST(threads1)
{
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  trivial_factory factory(2,2);
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_EQUAL(1, c.data()->i_);
  }
  CHECK_EQUAL(64, g.n_cells());
}
  

// 4 threads, uniform refinement to level 4: 2^(3*4) cells with data 1
TEST(threads2)
{
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  trivial_factory factory(4,4);
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_EQUAL(1, c.data()->i_);
  }
  CHECK_EQUAL(1<<(3*4), g.n_cells());
}

// 4 threads, refinement to level 4 followed by unification into one cell
// whose data is 2^(3*4)
TEST(threads3)
{
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  trivial_factory2 factory(4,4);
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_EQUAL(1<<(3*4), c.data()->i_);
  }
  CHECK_EQUAL(1, g.n_cells());
}

// stress test for race conditions: 12 threads refine to level `maxlevel`
// and unify everything back into a single cell, which must then hold
// exactly 2^(3*maxlevel)
TEST(threads4)
{
  const int maxlevel=7;
  vec3d lo(0,0,0);
  vec3d hi(1,1,1);

  trivial_factory2 factory(maxlevel,12);
  T_grid g(lo, hi, factory);

  for (T_cell_tracker c(g.begin()), last(g.end()); c != last; ++c) {
    CHECK_EQUAL(1<<(3*maxlevel), c.data()->i_);
  }
  CHECK_EQUAL(1, g.n_cells());
}

